{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*`skorch`* is designed to maximize interoperability between `sklearn` and `pytorch`. The aim is to keep 99% of the flexibility of `pytorch` while being able to leverage most features of `sklearn`. Below, we show the basic usage of `skorch` and how it can be combined with `sklearn`.\n",
    "\n",
    "<table align=\"left\"><td>\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Basic_Usage.ipynb\">\n",
    "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
    "</td><td>\n",
    "<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/Basic_Usage.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows you how to use the basic functionality of `skorch`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of contents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* [Definition of the pytorch module](#Definition-of-the-pytorch-module)\n",
    "* [Training a classifier](#Training-a-classifier-and-making-predictions)\n",
    "  * [Dataset](#A-toy-binary-classification-task)\n",
    "  * [pytorch module](#Definition-of-the-pytorch-classification-module)\n",
    "  * [Model training](#Defining-and-training-the-neural-net-classifier)\n",
    "  * [Inference](#Making-predictions,-classification)\n",
    "* [Training a regressor](#Training-a-regressor)\n",
    "  * [Dataset](#A-toy-regression-task)\n",
    "  * [pytorch module](#Definition-of-the-pytorch-regression-module)\n",
    "  * [Model training](#Defining-and-training-the-neural-net-regressor)\n",
    "  * [Inference](#Making-predictions,-regression)\n",
    "* [Saving and loading a model](#Saving-and-loading-a-model)\n",
    "  * [Whole model](#Saving-the-whole-model)\n",
    "  * [Only parameters](#Saving-only-the-model-parameters)\n",
    "* [Usage with an sklearn Pipeline](#Usage-with-an-sklearn-Pipeline)\n",
    "* [Callbacks](#Callbacks)\n",
    "* [Grid search](#Usage-with-sklearn-GridSearchCV)\n",
    "  * [Special prefixes](#Special-prefixes)\n",
    "  * [Performing a grid search](#Performing-a-grid-search)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "! [ ! -z \"$COLAB_GPU\" ] && pip install torch skorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a classifier and making predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A toy binary classification task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load a toy classification task from `sklearn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n",
    "X = X.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 20), (1000,), 0.5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape, y.shape, y.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of the `pytorch` classification `module`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes. It should also apply a softmax nonlinearity, because `predict_proba` returns the output of the `forward` call, which should therefore consist of probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_units=10,\n",
    "            nonlin=F.relu,\n",
    "            dropout=0.5,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.num_units = num_units\n",
    "        self.nonlin = nonlin\n",
    "\n",
    "        self.dense0 = nn.Linear(20, num_units)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.dense1 = nn.Linear(num_units, 10)\n",
    "        self.output = nn.Linear(10, 2)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = self.dropout(X)\n",
    "        X = F.relu(self.dense1(X))\n",
    "        X = F.softmax(self.output(X), dim=-1)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining and training the neural net classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use `NeuralNetClassifier` because we're dealing with a classification task. The first argument should be the `pytorch module`. As additional arguments, we pass the number of epochs and the learning rate (`lr`), but those are optional.\n",
    "\n",
    "*Note*: To use the CUDA backend, pass `device='cuda'` as an additional argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "#     device='cuda',  # uncomment this to train with CUDA\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As in `sklearn`, we call `fit`, passing the input data `X` and the targets `y`. By default, `NeuralNetClassifier` uses a `StratifiedKFold` split to hold out 20% of the data for validation. The resulting validation loss is printed for each epoch, along with the train loss and the accuracy on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.6905\u001b[0m       \u001b[32m0.6150\u001b[0m        \u001b[35m0.6749\u001b[0m  0.0235\n",
      "      2        \u001b[36m0.6648\u001b[0m       \u001b[32m0.6450\u001b[0m        \u001b[35m0.6633\u001b[0m  0.0213\n",
      "      3        \u001b[36m0.6619\u001b[0m       \u001b[32m0.6750\u001b[0m        \u001b[35m0.6533\u001b[0m  0.0219\n",
      "      4        \u001b[36m0.6429\u001b[0m       \u001b[32m0.6800\u001b[0m        \u001b[35m0.6399\u001b[0m  0.0207\n",
      "      5        \u001b[36m0.6307\u001b[0m       \u001b[32m0.6950\u001b[0m        \u001b[35m0.6254\u001b[0m  0.0192\n",
      "      6        \u001b[36m0.6291\u001b[0m       \u001b[32m0.7000\u001b[0m        \u001b[35m0.6134\u001b[0m  0.0202\n",
      "      7        \u001b[36m0.6102\u001b[0m       \u001b[32m0.7100\u001b[0m        \u001b[35m0.6033\u001b[0m  0.0220\n",
      "      8        \u001b[36m0.6050\u001b[0m       0.7000        \u001b[35m0.5931\u001b[0m  0.0210\n",
      "      9        \u001b[36m0.5966\u001b[0m       0.7000        \u001b[35m0.5844\u001b[0m  0.0217\n",
      "     10        \u001b[36m0.5636\u001b[0m       0.7100        \u001b[35m0.5689\u001b[0m  0.0226\n",
      "     11        0.5757       \u001b[32m0.7200\u001b[0m        \u001b[35m0.5628\u001b[0m  0.0196\n",
      "     12        0.5757       0.7200        \u001b[35m0.5520\u001b[0m  0.0190\n",
      "     13        \u001b[36m0.5559\u001b[0m       \u001b[32m0.7300\u001b[0m        \u001b[35m0.5459\u001b[0m  0.0218\n",
      "     14        \u001b[36m0.5541\u001b[0m       0.7300        \u001b[35m0.5424\u001b[0m  0.0206\n",
      "     15        0.5659       \u001b[32m0.7350\u001b[0m        \u001b[35m0.5378\u001b[0m  0.0215\n",
      "     16        \u001b[36m0.5364\u001b[0m       0.7350        \u001b[35m0.5322\u001b[0m  0.0192\n",
      "     17        0.5456       0.7300        \u001b[35m0.5239\u001b[0m  0.0221\n",
      "     18        0.5476       \u001b[32m0.7450\u001b[0m        0.5260  0.0188\n",
      "     19        0.5499       \u001b[32m0.7500\u001b[0m        0.5249  0.0213\n",
      "     20        \u001b[36m0.5273\u001b[0m       0.7350        0.5251  0.0206\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Making predictions, classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = net.predict(X[:5])\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.5349464 , 0.46505365],\n",
       "       [0.8685093 , 0.1314907 ],\n",
       "       [0.6860039 , 0.31399614],\n",
       "       [0.9126012 , 0.08739878],\n",
       "       [0.69675475, 0.30324525]], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_proba = net.predict_proba(X[:5])\n",
    "y_proba"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a regressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A toy regression task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)\n",
    "X_regr = X_regr.astype(np.float32)\n",
    "y_regr = y_regr.astype(np.float32) / 100\n",
    "y_regr = y_regr.reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 20), (1000, 1), -6.4901485, 6.154505)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note*: Regression currently requires the target to be 2-dimensional, hence the need to reshape. This should be fixed with an upcoming version of pytorch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of the `pytorch` regression `module`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, define a vanilla neural network with two hidden layers. The main difference is that the output layer only has one unit and does not apply a softmax nonlinearity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RegressorModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_units=10,\n",
    "            nonlin=F.relu,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.num_units = num_units\n",
    "        self.nonlin = nonlin\n",
    "\n",
    "        self.dense0 = nn.Linear(20, num_units)\n",
    "        self.dense1 = nn.Linear(num_units, 10)\n",
    "        self.output = nn.Linear(10, 1)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = F.relu(self.dense1(X))\n",
    "        X = self.output(X)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining and training the neural net regressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training a regressor is almost the same as training a classifier. Mainly, we use `NeuralNetRegressor` instead of `NeuralNetClassifier` (this is the same terminology as in `sklearn`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_regr = NeuralNetRegressor(\n",
    "    RegressorModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "#     device='cuda',  # uncomment this to train with CUDA\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_loss     dur\n",
      "-------  ------------  ------------  ------\n",
      "      1        \u001b[36m4.4168\u001b[0m        \u001b[32m3.0788\u001b[0m  0.0292\n",
      "      2        \u001b[36m2.0120\u001b[0m        \u001b[32m0.4565\u001b[0m  0.0270\n",
      "      3        \u001b[36m0.3343\u001b[0m        \u001b[32m0.2262\u001b[0m  0.0263\n",
      "      4        \u001b[36m0.1851\u001b[0m        \u001b[32m0.2223\u001b[0m  0.0257\n",
      "      5        \u001b[36m0.1491\u001b[0m        \u001b[32m0.1068\u001b[0m  0.0242\n",
      "      6        \u001b[36m0.0946\u001b[0m        0.1207  0.0263\n",
      "      7        \u001b[36m0.0739\u001b[0m        \u001b[32m0.0663\u001b[0m  0.0290\n",
      "      8        \u001b[36m0.0554\u001b[0m        0.0706  0.0298\n",
      "      9        \u001b[36m0.0437\u001b[0m        \u001b[32m0.0461\u001b[0m  0.0337\n",
      "     10        \u001b[36m0.0372\u001b[0m        0.0469  0.0273\n",
      "     11        \u001b[36m0.0291\u001b[0m        \u001b[32m0.0343\u001b[0m  0.0263\n",
      "     12        \u001b[36m0.0270\u001b[0m        \u001b[32m0.0333\u001b[0m  0.0285\n",
      "     13        \u001b[36m0.0207\u001b[0m        \u001b[32m0.0265\u001b[0m  0.0281\n",
      "     14        \u001b[36m0.0196\u001b[0m        \u001b[32m0.0249\u001b[0m  0.0344\n",
      "     15        \u001b[36m0.0152\u001b[0m        \u001b[32m0.0215\u001b[0m  0.0286\n",
      "     16        \u001b[36m0.0151\u001b[0m        \u001b[32m0.0198\u001b[0m  0.0281\n",
      "     17        \u001b[36m0.0120\u001b[0m        \u001b[32m0.0182\u001b[0m  0.0283\n",
      "     18        \u001b[36m0.0119\u001b[0m        \u001b[32m0.0167\u001b[0m  0.0266\n",
      "     19        \u001b[36m0.0100\u001b[0m        \u001b[32m0.0159\u001b[0m  0.0266\n",
      "     20        \u001b[36m0.0097\u001b[0m        \u001b[32m0.0149\u001b[0m  0.0259\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.regressor.NeuralNetRegressor'>[initialized](\n",
       "  module_=RegressorModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=1, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net_regr.fit(X_regr, y_regr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Making predictions, regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may call `predict` or `predict_proba` on the fitted model. For regression, both methods return the same values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.4903931 ],\n",
       "       [-1.4224019 ],\n",
       "       [-0.77500594],\n",
       "       [-0.06901944],\n",
       "       [-0.3867012 ]], dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = net_regr.predict(X_regr[:5])\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and loading a model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can save and load either the whole model, using `pickle`, or only the learned model parameters, by calling `save_params` and `load_params`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving the whole model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name = '/tmp/mymodel.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/thomasfan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:241: UserWarning: Couldn't retrieve source code for container of type ClassifierModule. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "with open(file_name, 'wb') as f:\n",
    "    pickle.dump(net, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(file_name, 'rb') as f:\n",
    "    new_net = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving only the model parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This saves and loads only the `module` parameters, meaning that hyperparameters such as `lr` and `max_epochs` are not saved. Therefore, to load the parameters, we first have to initialize a new net."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.save_params(f_params=file_name)  # a file handler also works"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first initialize the model\n",
    "new_net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    ").initialize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_net.load_params(f_params=file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage with an `sklearn Pipeline`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to put the `NeuralNetClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = Pipeline([\n",
    "    ('scale', StandardScaler()),\n",
    "    ('net', net),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Re-initializing module!\n",
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.7243\u001b[0m       \u001b[32m0.5000\u001b[0m        \u001b[35m0.7105\u001b[0m  0.0184\n",
      "      2        \u001b[36m0.7057\u001b[0m       0.5000        \u001b[35m0.6996\u001b[0m  0.0207\n",
      "      3        \u001b[36m0.6971\u001b[0m       0.5000        \u001b[35m0.6949\u001b[0m  0.0192\n",
      "      4        \u001b[36m0.6936\u001b[0m       \u001b[32m0.5050\u001b[0m        \u001b[35m0.6929\u001b[0m  0.0224\n",
      "      5        \u001b[36m0.6923\u001b[0m       \u001b[32m0.5400\u001b[0m        \u001b[35m0.6916\u001b[0m  0.0210\n",
      "      6        \u001b[36m0.6905\u001b[0m       0.5000        \u001b[35m0.6906\u001b[0m  0.0189\n",
      "      7        \u001b[36m0.6894\u001b[0m       0.5100        \u001b[35m0.6899\u001b[0m  0.0194\n",
      "      8        \u001b[36m0.6891\u001b[0m       0.5150        \u001b[35m0.6892\u001b[0m  0.0186\n",
      "      9        0.6899       0.5250        \u001b[35m0.6885\u001b[0m  0.0202\n",
      "     10        \u001b[36m0.6844\u001b[0m       0.5300        \u001b[35m0.6876\u001b[0m  0.0189\n",
      "     11        0.6853       \u001b[32m0.5650\u001b[0m        \u001b[35m0.6865\u001b[0m  0.0199\n",
      "     12        \u001b[36m0.6842\u001b[0m       \u001b[32m0.5700\u001b[0m        \u001b[35m0.6855\u001b[0m  0.0183\n",
      "     13        \u001b[36m0.6821\u001b[0m       \u001b[32m0.5850\u001b[0m        \u001b[35m0.6844\u001b[0m  0.0199\n",
      "     14        \u001b[36m0.6821\u001b[0m       \u001b[32m0.6050\u001b[0m        \u001b[35m0.6832\u001b[0m  0.0189\n",
      "     15        \u001b[36m0.6820\u001b[0m       \u001b[32m0.6100\u001b[0m        \u001b[35m0.6820\u001b[0m  0.0206\n",
      "     16        \u001b[36m0.6769\u001b[0m       0.6100        \u001b[35m0.6800\u001b[0m  0.0188\n",
      "     17        0.6784       \u001b[32m0.6200\u001b[0m        \u001b[35m0.6780\u001b[0m  0.0219\n",
      "     18        \u001b[36m0.6763\u001b[0m       \u001b[32m0.6450\u001b[0m        \u001b[35m0.6761\u001b[0m  0.0233\n",
      "     19        \u001b[36m0.6704\u001b[0m       \u001b[32m0.6550\u001b[0m        \u001b[35m0.6729\u001b[0m  0.0254\n",
      "     20        \u001b[36m0.6691\u001b[0m       \u001b[32m0.6750\u001b[0m        \u001b[35m0.6699\u001b[0m  0.0252\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('scale', StandardScaler(copy=True, with_mean=True, with_std=True)), ('net', <class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       "))])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.5064775 , 0.49352255],\n",
       "       [0.53243965, 0.46756038],\n",
       "       [0.57306874, 0.42693123],\n",
       "       [0.54179883, 0.45820117],\n",
       "       [0.5528906 , 0.44710937]], dtype=float32)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_proba = pipe.predict_proba(X[:5])\n",
    "y_proba"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To save the whole pipeline, including the pytorch module, use `pickle`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adding a new callback to the model is straightforward. Below we show how to add a new callback that computes the area under the ROC curve (AUC)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.callbacks import EpochScoring"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is a scoring callback in skorch, `EpochScoring`, which we use for this. We have to specify which score to calculate. We have 3 choices:\n",
    "\n",
    "* Passing a string: This should be a valid `sklearn` metric. For a list of all existing scores, look [here](http://scikit-learn.org/stable/modules/classes.html#sklearn-metrics-metrics).\n",
    "* Passing `None`: If you implement your own `.score` method on your neural net, passing `scoring=None` will tell `skorch` to use that.\n",
    "* Passing a function or callable: If we want to define our own scoring function, we pass a function with the signature `func(model, X, y) -> score`, which is then used.\n",
    "\n",
    "Note that this works exactly the same as scoring in `sklearn` does."
   ]
  },
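  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the third option, the sketch below defines a scoring callable with the `func(model, X, y) -> score` signature. The names `accuracy_scorer` and `ThresholdModel` are hypothetical and not part of `skorch`; the stand-in model exists only so that the snippet runs without a fitted net.\n",
    "\n",
    "```python\n",
    "def accuracy_scorer(model, X, y):\n",
    "    # Fraction of samples whose predicted label equals the target.\n",
    "    y_pred = model.predict(X)\n",
    "    correct = sum(int(p == t) for p, t in zip(y_pred, y))\n",
    "    return correct / len(y)\n",
    "\n",
    "\n",
    "class ThresholdModel:\n",
    "    # Hypothetical stand-in for a fitted net:\n",
    "    # predicts class 1 when the first feature is positive.\n",
    "    def predict(self, X):\n",
    "        return [int(row[0] > 0) for row in X]\n",
    "\n",
    "\n",
    "X_demo = [[0.5], [-1.0], [2.0], [-0.3]]\n",
    "y_demo = [1, 0, 0, 0]\n",
    "score = accuracy_scorer(ThresholdModel(), X_demo, y_demo)\n",
    "print(score)\n",
    "```\n",
    "\n",
    "With a real net, a function like this could be passed as `EpochScoring(scoring=accuracy_scorer, lower_is_better=False)`."
   ]
  },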
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our case here, since `sklearn` already implements AUC, we just pass the correct string `'roc_auc'`. We should also tell the callback that higher scores are better (to get the correct colors printed below -- by default, lower scores are assumed to be better). Furthermore, we may specify a `name` argument for `EpochScoring`, and whether to use training data (by setting `on_train=True`) or validation data (which is the default)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "auc = EpochScoring(scoring='roc_auc', lower_is_better=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we pass the scoring callback to the `callbacks` parameter as a list and then call `fit`. Notice that we get the printed scores and color highlighting for free."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    callbacks=[auc],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    roc_auc    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ---------  ------------  -----------  ------------  ------\n",
      "      1     \u001b[36m0.6112\u001b[0m        \u001b[32m0.7076\u001b[0m       \u001b[35m0.5550\u001b[0m        \u001b[31m0.6802\u001b[0m  0.0188\n",
      "      2     \u001b[36m0.6766\u001b[0m        \u001b[32m0.6750\u001b[0m       \u001b[35m0.6150\u001b[0m        \u001b[31m0.6626\u001b[0m  0.0204\n",
      "      3     \u001b[36m0.7031\u001b[0m        \u001b[32m0.6560\u001b[0m       \u001b[35m0.6500\u001b[0m        \u001b[31m0.6498\u001b[0m  0.0244\n",
      "      4     \u001b[36m0.7201\u001b[0m        \u001b[32m0.6364\u001b[0m       \u001b[35m0.6650\u001b[0m        \u001b[31m0.6381\u001b[0m  0.0193\n",
      "      5     \u001b[36m0.7316\u001b[0m        \u001b[32m0.6176\u001b[0m       \u001b[35m0.6900\u001b[0m        \u001b[31m0.6285\u001b[0m  0.0203\n",
      "      6     \u001b[36m0.7447\u001b[0m        \u001b[32m0.6094\u001b[0m       \u001b[35m0.7200\u001b[0m        \u001b[31m0.6183\u001b[0m  0.0222\n",
      "      7     \u001b[36m0.7522\u001b[0m        0.6170       0.7200        \u001b[31m0.6090\u001b[0m  0.0188\n",
      "      8     \u001b[36m0.7567\u001b[0m        \u001b[32m0.5786\u001b[0m       0.7150        \u001b[31m0.6032\u001b[0m  0.0197\n",
      "      9     \u001b[36m0.7630\u001b[0m        0.5850       0.7100        \u001b[31m0.5954\u001b[0m  0.0214\n",
      "     10     \u001b[36m0.7706\u001b[0m        \u001b[32m0.5770\u001b[0m       0.7200        \u001b[31m0.5889\u001b[0m  0.0207\n",
      "     11     \u001b[36m0.7735\u001b[0m        \u001b[32m0.5740\u001b[0m       0.7050        \u001b[31m0.5842\u001b[0m  0.0188\n",
      "     12     0.7729        0.5771       0.7100        0.5859  0.0186\n",
      "     13     \u001b[36m0.7792\u001b[0m        \u001b[32m0.5557\u001b[0m       0.7000        \u001b[31m0.5745\u001b[0m  0.0178\n",
      "     14     \u001b[36m0.7825\u001b[0m        0.5810       0.7050        \u001b[31m0.5691\u001b[0m  0.0204\n",
      "     15     0.7824        0.5634       0.7200        \u001b[31m0.5691\u001b[0m  0.0194\n",
      "     16     0.7817        0.5778       0.7150        0.5704  0.0205\n",
      "     17     \u001b[36m0.7871\u001b[0m        0.5624       0.7150        \u001b[31m0.5633\u001b[0m  0.0188\n",
      "     18     0.7855        0.5613       0.7200        0.5660  0.0202\n",
      "     19     0.7792        0.5637       \u001b[35m0.7250\u001b[0m        0.5722  0.0194\n",
      "     20     0.7823        \u001b[32m0.5516\u001b[0m       0.7150        0.5681  0.0184\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=ClassifierModule(\n",
       "    (dense0): Linear(in_features=20, out_features=10, bias=True)\n",
       "    (dropout): Dropout(p=0.5)\n",
       "    (dense1): Linear(in_features=10, out_features=10, bias=True)\n",
       "    (output): Linear(in_features=10, out_features=2, bias=True)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For information on how to write custom callbacks, have a look at the [Advanced_Usage](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb) notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage with sklearn `GridSearchCV`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Special prefixes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `NeuralNet` class lets you directly access parameters of the `pytorch` `module` by using the `module__` prefix. For example, if your `module` has a `num_units` parameter, you can set it via the `module__num_units` argument. This is the same logic that gives access to estimator parameters in `sklearn` `Pipeline`s and `FeatureUnion`s."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This feature is useful in several ways. For one, it allows you to set those parameters in the model definition. Furthermore, it allows you to set parameters in an `sklearn` `GridSearchCV`, as shown below."
   ]
  },
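  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the first point, the cell below sets a module parameter when the net is defined and then changes it with `set_params`. It assumes that `ClassifierModule` (defined earlier in this notebook) accepts a `num_units` argument; the name `net_sketch` is only illustrative and is not used elsewhere."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetClassifier\n",
    "\n",
    "# ClassifierModule is the module class defined earlier in this notebook.\n",
    "net_sketch = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    module__num_units=20,  # passed on to ClassifierModule as num_units=20\n",
    ")\n",
    "\n",
    "# set_params follows the sklearn convention, so the same prefixes work here.\n",
    "net_sketch.set_params(module__num_units=30, optimizer__momentum=0.9)"
   ]
  },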
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to the parameters prefixed by `module__`, you can access the parameters of several other components, such as those of the optimizer via the `optimizer__` prefix (again, see below). All these special prefixes are stored in the `prefixes_` attribute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module, iterator_train, iterator_valid, optimizer, criterion, callbacks, dataset\n"
     ]
    }
   ],
   "source": [
    "print(', '.join(net.prefixes_))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Performing a grid search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we show how to perform a grid search over the learning rate (`lr`), the module's number of hidden units (`module__num_units`), the module's dropout rate (`module__dropout`), and whether the SGD optimizer should use Nesterov momentum or not (`optimizer__nesterov`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=20,\n",
    "    lr=0.1,\n",
    "    verbose=0,\n",
    "    optimizer__momentum=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'lr': [0.05, 0.1],\n",
    "    'module__num_units': [10, 20],\n",
    "    'module__dropout': [0, 0.5],\n",
    "    'optimizer__nesterov': [False, True],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.3s remaining:    0.0s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n",
      "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n",
      "[CV]  lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total=   0.3s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Done  48 out of  48 | elapsed:   15.7s finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=3, error_score='raise-deprecating',\n",
       "       estimator=<class 'skorch.classifier.NeuralNetClassifier'>[uninitialized](\n",
       "  module=<class '__main__.ClassifierModule'>,\n",
       "),\n",
       "       fit_params=None, iid='warn', n_jobs=None,\n",
       "       param_grid={'lr': [0.05, 0.1], 'module__num_units': [10, 20], 'module__dropout': [0, 0.5], 'optimizer__nesterov': [False, True]},\n",
       "       pre_dispatch='2*n_jobs', refit=False, return_train_score='warn',\n",
       "       scoring='accuracy', verbose=2)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gs.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.862 {'lr': 0.05, 'module__dropout': 0, 'module__num_units': 20, 'optimizer__nesterov': False}\n"
     ]
    }
   ],
   "source": [
    "print(gs.best_score_, gs.best_params_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Of course, we could further nest the `NeuralNetClassifier` within an `sklearn` `Pipeline`, in which case we prefix each parameter with the name of the pipeline step that holds the net (e.g. `net__module__num_units` if the step is named `net`)."
   ]
  }
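  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of that setup, reusing `net`, `X`, and `y` from above. The step names `scale` and `net`, the choice of `StandardScaler`, and the reduced parameter grid are illustrative assumptions; the sketch also assumes `X` is a `float32` array so that the scaled data keeps a dtype the module accepts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# The net becomes one named step of the pipeline.\n",
    "pipe = Pipeline([\n",
    "    ('scale', StandardScaler()),\n",
    "    ('net', net),\n",
    "])\n",
    "\n",
    "# Parameter names start with the step name, followed by the skorch prefix.\n",
    "pipe_params = {\n",
    "    'net__lr': [0.05, 0.1],\n",
    "    'net__module__num_units': [10, 20],\n",
    "}\n",
    "\n",
    "gs_pipe = GridSearchCV(pipe, pipe_params, refit=False, cv=3, scoring='accuracy')\n",
    "gs_pipe.fit(X, y)\n",
    "print(gs_pipe.best_score_, gs_pipe.best_params_)"
   ]
  }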
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
